{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2676df8c",
   "metadata": {},
   "source": [
    "# Yuan2.0-2B Lora bf16微调"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3e959377",
   "metadata": {},
   "source": [
    "# 模型下载\n",
    "使用 modelscope 中的 snapshot_download 函数下载模型，第一个参数为模型名称，参数 cache_dir 为模型的下载路径。\n",
    "\n",
    "这里可以先进入 AutoDL 平台，初始化机器对应区域的文件存储，文件存储路径为 `/root/autodl-fs`。\n",
    "该存储中的文件不会随着机器的关闭而丢失，这样可以避免模型二次下载。\n",
    "\n",
    "![autodl-fs](images/autodl-fs.png)\n",
    "\n",
    "然后运行下面代码，执行模型下载。模型大小为 4.5GB，下载大概需要 5 分钟。\n",
    "\n",
    "```python\n",
    "from modelscope import snapshot_download\n",
    "model_dir = snapshot_download('YuanLLM/Yuan2-2B-Mars-hf', cache_dir='/root/autodl-fs')\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "971eb160",
   "metadata": {},
   "source": [
    "# 导入环境"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9f62ec38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from datasets import Dataset\n",
    "from transformers import LlamaTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e098d9eb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the JSON file into a pandas DataFrame, then wrap it as a `datasets.Dataset`\n",
    "df = pd.read_json('../dataset/huanhuan-100.json')\n",
    "ds = Dataset.from_pandas(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8ac92d42-efae-49b1-a00e-ccaa75b98938",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'instruction': ['你是谁？', '你好'],\n",
       " 'input': ['', ''],\n",
       " 'output': ['我是甄嬛，家父是大理寺少卿甄远道。', '皇上好，我是甄嬛，家父是大理寺少卿甄远道。']}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first two records (columns: instruction / input / output)\n",
    "ds[:2]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51d05e5d-d14e-4f03-92be-9a9677d41918",
   "metadata": {},
   "source": [
    "# 处理数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "74ee5a67-2e55-4974-b90e-cbf492de500a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
     ]
    }
   ],
   "source": [
    "path = '/root/autodl-fs/YuanLLM/Yuan2-2B-Mars-hf'\n",
    "\n",
    "# Load the tokenizer with automatic BOS/EOS insertion turned off and '<eod>' as the EOS token\n",
    "tokenizer = LlamaTokenizer.from_pretrained(\n",
    "    path,\n",
    "    add_eos_token=False,\n",
    "    add_bos_token=False,\n",
    "    eos_token='<eod>',\n",
    ")\n",
    "\n",
    "# Register the extra special tokens (including the '<sep>' prompt separator used below)\n",
    "tokenizer.add_tokens(\n",
    "    [\n",
    "        '<sep>', '<pad>', '<mask>', '<predict>',\n",
    "        '<FIM_SUFFIX>', '<FIM_PREFIX>', '<FIM_MIDDLE>',\n",
    "        '<commit_before>', '<commit_msg>', '<commit_after>',\n",
    "        '<jupyter_start>', '<jupyter_text>', '<jupyter_code>', '<jupyter_output>', '<empty_output>',\n",
    "    ],\n",
    "    special_tokens=True,\n",
    ")\n",
    "\n",
    "# Reuse the EOS token for padding\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2503a5fa-9621-4495-9035-8e7ef6525691",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def process_func(example, max_length=384):\n",
    "    \"\"\"Tokenize one record into causal-LM training features.\n",
    "\n",
    "    The prompt is ``<instruction><sep>`` and the target is ``<output><eod>``.\n",
    "    Prompt positions are labelled -100 so that only the response contributes\n",
    "    to the loss. The record's ``input`` field is not used.\n",
    "\n",
    "    Args:\n",
    "        example: Mapping with at least ``instruction`` and ``output`` keys.\n",
    "        max_length: Maximum number of tokens kept per sample; longer sequences\n",
    "            are cut from the right, ``None`` disables truncation. The default\n",
    "            leaves headroom because the tokenizer may split a single Chinese\n",
    "            character into several tokens.\n",
    "\n",
    "    Returns:\n",
    "        dict with ``input_ids``, ``attention_mask`` and ``labels`` lists of\n",
    "        equal length.\n",
    "    \"\"\"\n",
    "    prompt_ids = tokenizer(f\"{example['instruction']}<sep>\")[\"input_ids\"]\n",
    "    answer_ids = tokenizer(f\"{example['output']}<eod>\")[\"input_ids\"]\n",
    "\n",
    "    # Slicing is a no-op for sequences that are already short enough\n",
    "    input_ids = (prompt_ids + answer_ids)[:max_length]\n",
    "    # Mask the prompt so that the loss is computed on the response only\n",
    "    labels = ([-100] * len(prompt_ids) + answer_ids)[:max_length]\n",
    "\n",
    "    return {\n",
    "        \"input_ids\": input_ids,\n",
    "        \"attention_mask\": [1] * len(input_ids),\n",
    "        \"labels\": labels,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "84f870d6-73a9-4b0f-8abf-687b32224ad8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eb65f900df62429fb3e6af1d1e972ae1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/100 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input_ids', 'attention_mask', 'labels'],\n",
       "    num_rows: 100\n",
       "})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tokenize every record and drop the original text columns, keeping only what process_func returns\n",
    "tokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n",
    "tokenized_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1f7e15a0-4d9a-4935-9861-00cc472654b1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' 你是谁？<sep> 我是甄嬛，家父是大理寺少卿甄远道。<eod>'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: decode the first sample back to text (prompt + response)\n",
    "tokenizer.decode(tokenized_id[0]['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "97f16f66-324a-454f-8cc3-ef23b100ecff",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' 我是甄嬛，家父是大理寺少卿甄远道。<eod>'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: decode only the supervised positions (labels that are not -100)\n",
    "tokenizer.decode([token_id for token_id in tokenized_id[0][\"labels\"] if token_id != -100])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "424823a8-ed0d-4309-83c8-3f6b1cdf274c",
   "metadata": {},
   "source": [
    "# 创建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "170764e5-d899-4ef4-8c53-36f6dec0d198",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "YuanForCausalLM(\n",
       "  (model): YuanModel(\n",
       "    (embed_tokens): Embedding(135040, 2048, padding_idx=77185)\n",
       "    (layers): ModuleList(\n",
       "      (0-23): 24 x YuanDecoderLayer(\n",
       "        (self_attn): YuanAttention(\n",
       "          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "          (rotary_emb): YuanRotaryEmbedding()\n",
       "          (lf_gate): LocalizedFiltering(\n",
       "            (conv1): Conv2d(2048, 1024, kernel_size=(2, 1), stride=(1, 1), padding=(1, 0))\n",
       "            (conv2): Conv2d(1024, 2048, kernel_size=(2, 1), stride=(1, 1), padding=(1, 0))\n",
       "            (output_layernorm): YuanRMSNorm()\n",
       "          )\n",
       "          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "        )\n",
       "        (mlp): YuanMLP(\n",
       "          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)\n",
       "          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)\n",
       "          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): YuanRMSNorm()\n",
       "        (post_attention_layernorm): YuanRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): YuanRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=2048, out_features=135040, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Load the base model in bfloat16; device placement is delegated to device_map=\"auto\"\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    path,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=\"auto\",\n",
    "    trust_remote_code=True,  # NOTE(review): presumably needed for the checkpoint's custom Yuan model code - confirm\n",
    ")\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2323eac7-37d5-4288-8bc5-79fac7113402",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.enable_input_require_grads() # Must be called when gradient checkpointing is enabled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f808b05c-f2cb-48cf-a80d-0c42be6051c7",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.bfloat16"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confirm the dtype the base model was loaded in\n",
    "model.dtype"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13d71257-3c1c-4303-8ff8-af161ebc2cf1",
   "metadata": {},
   "source": [
    "# 配置 LoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2d304ae2-ab60-4080-a80d-19cac2e3ade3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules={'o_proj', 'up_proj', 'down_proj', 'gate_proj', 'v_proj', 'k_proj', 'q_proj'}, lora_alpha=32, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "\n",
    "config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    inference_mode=False,  # training mode\n",
    "    r=8,  # LoRA rank\n",
    "    lora_alpha=32,  # LoRA alpha; see the LoRA method description for its role\n",
    "    lora_dropout=0.1,  # dropout ratio\n",
    "    # Adapt all attention and MLP projection layers\n",
    "    target_modules=[\n",
    "        \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
    "        \"gate_proj\", \"up_proj\", \"down_proj\",\n",
    "    ],\n",
    ")\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2c2489c5-eaab-4e1f-b06a-c3f914b4bf8e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='/root/autodl-fs/YuanLLM/Yuan2-2B-Mars-hf', revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules={'o_proj', 'up_proj', 'down_proj', 'gate_proj', 'v_proj', 'k_proj', 'q_proj'}, lora_alpha=32, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wrap the base model with the LoRA adapters, then display the config again.\n",
    "# NOTE: `model` is rebound here, so re-running this cell alone would wrap an\n",
    "# already-wrapped model; re-run the model-loading cell first.\n",
    "model = get_peft_model(model, config)\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ebf5482b-fab9-4eb3-ad88-c116def4be12",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 9,043,968 || all params: 2,097,768,448 || trainable%: 0.4311\n"
     ]
    }
   ],
   "source": [
    "# Report trainable vs. total parameter counts after adding the adapters\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "41ec69a8-8ab7-4256-9f30-803096f07fcb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.bfloat16"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the dtype reported by the PEFT-wrapped model\n",
    "model.dtype"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca055683-837f-4865-9c57-9164ba60c00f",
   "metadata": {},
   "source": [
    "# 配置训练参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7e76bbff-15fd-4995-a61d-8364dc5e9ea0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=\"./output/Yuan2.0-2B_lora_bf16\",\n",
    "    # optimisation schedule\n",
    "    num_train_epochs=3,\n",
    "    learning_rate=5e-5,\n",
    "    per_device_train_batch_size=4,\n",
    "    gradient_accumulation_steps=1,\n",
    "    gradient_checkpointing=True,\n",
    "    # logging and checkpointing\n",
    "    logging_steps=10,\n",
    "    save_steps=10,  # kept small for a quick demo; 100 is the suggested value\n",
    "    save_on_each_node=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f142cb9c-ad99-48e6-ba86-6df198f9ed96",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    }
   ],
   "source": [
    "# Build the Trainer from the LoRA-wrapped model, the training arguments and the tokenized dataset.\n",
    "# The collator is created with padding=True so samples of different lengths can be batched.\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=tokenized_id,\n",
    "    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "aec9bc36-b297-45af-99e1-d4c4d82be081",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='75' max='75' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [75/75 00:20, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>3.307000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.857400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.062400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.000200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /root/autodl-fs/YuanLLM/Yuan2-2B-Mars-hf - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /root/autodl-fs/YuanLLM/Yuan2-2B-Mars-hf - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /root/autodl-fs/YuanLLM/Yuan2-2B-Mars-hf - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /root/autodl-fs/YuanLLM/Yuan2-2B-Mars-hf - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /root/autodl-fs/YuanLLM/Yuan2-2B-Mars-hf - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /root/autodl-fs/YuanLLM/Yuan2-2B-Mars-hf - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /root/autodl-fs/YuanLLM/Yuan2-2B-Mars-hf - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=75, training_loss=0.5636235364278157, metrics={'train_runtime': 21.937, 'train_samples_per_second': 13.676, 'train_steps_per_second': 3.419, 'total_flos': 71857524768768.0, 'train_loss': 0.5636235364278157, 'epoch': 3.0})"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run fine-tuning; checkpoints are written under args.output_dir\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8abb2327-458e-4e96-ac98-2141b5b97c8e",
   "metadata": {},
   "source": [
    "# 合并加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "bd2a415a-a9ad-49ea-877f-243558a83bfc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 我是甄嬛，家父是大理寺少卿甄远道。<eod>\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, LlamaTokenizer\n",
    "import torch\n",
    "from peft import PeftModel\n",
    "\n",
    "path = '/root/autodl-fs/YuanLLM/Yuan2-2B-Mars-hf'\n",
    "lora_path = './output/Yuan2.0-2B_lora_bf16/checkpoint-70'  # change this to the checkpoint directory produced by your own run\n",
    "\n",
    "# Load the tokenizer with the same settings that were used for training\n",
    "tokenizer = LlamaTokenizer.from_pretrained(path, add_eos_token=False, add_bos_token=False, eos_token='<eod>')\n",
    "tokenizer.add_tokens(['<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>', '<FIM_PREFIX>', '<FIM_MIDDLE>','<commit_before>','<commit_msg>','<commit_after>','<jupyter_start>','<jupyter_text>','<jupyter_code>','<jupyter_output>','<empty_output>'], special_tokens=True)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Load the base model in bfloat16 and switch it to eval mode\n",
    "model = AutoModelForCausalLM.from_pretrained(path, device_map=\"auto\", torch_dtype=torch.bfloat16, trust_remote_code=True).eval()\n",
    "\n",
    "# Attach the fine-tuned LoRA weights to the base model\n",
    "model = PeftModel.from_pretrained(model, model_id=lora_path)\n",
    "\n",
    "prompt = \"你是谁？<sep>\"\n",
    "# Move the prompt to the device the model was placed on instead of assuming CUDA\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"].to(model.device)\n",
    "outputs = model.generate(inputs, do_sample=False, max_length=256)\n",
    "output = tokenizer.decode(outputs[0])\n",
    "# Print only the text generated after the '<sep>' separator\n",
    "print(output.split(\"<sep>\")[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f88a87c6-e330-430b-bd24-69f5180034c2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19 (default, Mar 20 2024, 19:58:24) \n[GCC 11.2.0]"
  },
  "vscode": {
   "interpreter": {
    "hash": "c6bd51bef5b4ff0d27ffa8191406cdd00b92260c116fd6161fe29291443fd9f8"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
